from functools import cached_property, reduce
from typing import List, Optional, Union
from copy import deepcopy
from collections import defaultdict
import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from hyperpyyaml import load_hyperpyyaml
from stepvocoder.cosyvoice2.cli.frontend import CosyVoiceFrontEnd
from stepvocoder.cosyvoice2.flow.flow import CausalMaskedDiffWithXvec
from stepvocoder.cosyvoice2.hifigan.generator import HiFTGenerator
from stepvocoder.cosyvoice2.bigvgan.bigvgan import BigVGAN
# from stepvocoder.cosyvoice2.utils.common import fade_in_out
import threading

"""perform fade_in_out in tensor style
"""
def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
    """Cross-fade the head of ``fade_in_mel`` with the tail of ``fade_out_mel``.

    The window is split in the middle: its first half weights the leading
    samples of ``fade_in_mel`` and its second half weights the trailing
    samples of ``fade_out_mel``; the two weighted pieces are summed.
    Blending happens along the last dimension.

    Args:
        fade_in_mel: tensor whose leading samples are blended; not modified,
            a clone is returned.
        fade_out_mel: tensor whose trailing samples are mixed in.
        window: 1-D window; half of its length is the overlap size.

    Returns:
        A copy of ``fade_in_mel`` with the overlap region replaced by the blend.
    """
    overlap = int(window.shape[0] / 2)
    rising_half = window[:overlap]
    falling_half = window[overlap:]
    blended = fade_in_mel.clone()
    incoming = blended[..., :overlap] * rising_half
    outgoing = fade_out_mel[..., -overlap:] * falling_half
    blended[..., :overlap] = incoming + outgoing
    return blended


# torch._dynamo.config.cache_size_limit = 128
# torch._dynamo.config.accumulated_cache_size_limit = 128


"""
A wrapper for managing stream caches. 
"""
class CosyVoice_stream_impl_(torch.nn.Module):
    """Token-to-waveform wrapper around a flow model and a vocoder.

    Provides a one-shot path (``token2wav_nonstream``) and a chunked streaming
    path (``token2wav_stream``). For streaming, all per-session state (token
    buffer, chunk-size schedule, flow cache, speaker embedding, vocoder cache,
    first-chunk flag) lives in dictionaries keyed by ``session_id`` and is
    released by ``clean_up``.
    """

    def __init__(self,
                 flow: "CausalMaskedDiffWithXvec",
                 hift: "Union[HiFTGenerator, BigVGAN]",
                 chunk_size_list: List = [15, 24, 48],  # (0.6s, 0.96s, 1.92s)
                 mel_cache_len: int = 8,
                 n_timesteps: int = 10,  # for both stream/non-stream
                 ):
        """Store the models and create empty per-session caches.

        Args:
            flow: flow model; this class uses its ``pre_lookahead_len``,
                ``inference``, ``setup_cache`` and ``inference_chunk``.
            hift: vocoder, either ``HiFTGenerator`` or ``BigVGAN``.
            chunk_size_list: chunk-size schedule copied for every new session;
                the head entry is dropped after each synthesized chunk until
                a single entry remains.
            mel_cache_len: number of trailing mel frames kept between chunks.
            n_timesteps: forwarded to every flow call.

        Raises:
            ValueError: if ``hift`` is neither ``BigVGAN`` nor ``HiFTGenerator``.
        """
        super().__init__()
        self.flow = flow
        self.hift = hift
        self.n_timesteps = n_timesteps
        # hard coded! (sample rate is not read from the vocoder)
        self.token_lookahead = flow.pre_lookahead_len
        # stream conf
        self.mel_cache_len = mel_cache_len

        if isinstance(self.hift, BigVGAN):
            # bigvgan use left 3 frames and right 3 frames as context
            self.source_cache_len = int((mel_cache_len - 6) * 480)   # 50hz mel -> 24k wave
        elif isinstance(self.hift, HiFTGenerator):
            self.source_cache_len = int(mel_cache_len * 480)   # 50hz mel -> 24k wave
        else:
            raise ValueError(f'unsupported vocoder type {type(self.hift)}')

        # Hamming window used to cross-fade the waveform overlap between chunks.
        self.register_buffer('speech_window', torch.from_numpy(np.hamming(2 * self.source_cache_len)), persistent=False)
        # session management
        self.speech_token_dict = defaultdict(list)
        # Copy so that neither the shared default list nor a caller-owned
        # sequence is aliased by this instance.
        self.chunk_size_list = list(chunk_size_list)
        self.chunk_size_dict = {}
        self.b_first_chunk_dict = {}  # indicate if it's the first chunk of this session
        # hifigan cache
        self.hift_cache_dict = {}
        # model att/cnn cache
        self.chunk_cache_dict = {}
        self.estimator_prompt_length_dict = {}
        # speaker embedding cache
        self.spk_embedding_cache_dict = {}
        # setup lock (held while a session cache is being created)
        self.setup_lock = threading.Lock()

    @cached_property
    def device(self):
        """Device of the vocoder's first parameter.

        Cached on first access, so it is not refreshed if the module is moved
        afterwards.
        """
        return next(self.hift.parameters()).device

    @cached_property
    def dtype(self):
        """Dtype of the vocoder's first parameter (cached on first access)."""
        return next(self.hift.parameters()).dtype

    def token2wav_nonstream(self,
                            token: torch.Tensor,
                            prompt_token: torch.Tensor,
                            prompt_feat: torch.Tensor,
                            embedding: torch.Tensor,
                            ):
        """NOTE Non-stream interface.

        Converts a whole mixed token sequence to a waveform in one call.

        Args:
            token: mixed vq0206 token sequence; flattened before regrouping.
            prompt_token: mixed prompt token sequence; flattened likewise.
            prompt_feat: prompt mel; resampled (nearest) along dim 1 to twice
                the number of regrouped prompt-token steps.
            embedding: speaker embedding passed to the flow model.

        Returns:
            The synthesized speech as a float32 CPU tensor.

        Raises:
            ValueError: if the vocoder type is unsupported.
        """
        def _make_len(ts: torch.Tensor):
            return torch.tensor([ts.shape[1]], dtype=torch.long, device=ts.device)

        # [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]
        # flatten() (instead of squeeze()) keeps a one-element tensor a list.
        token = self._reshape(token.flatten().tolist()).unsqueeze(0)
        prompt_token = self._reshape(prompt_token.flatten().tolist()).unsqueeze(0)
        # align prompt mel
        prompt_feat = F.interpolate(
            prompt_feat.transpose(1, 2),
            size=prompt_token.shape[1] * 2,
            mode='nearest'
        ).transpose(1, 2)

        token, prompt_token, prompt_feat, embedding = map(
            lambda ts: ts.to(self.device),
            (token, prompt_token, prompt_feat, embedding),
        )
        # inference flow
        mel = self.flow.inference(
            token,
            _make_len(token),
            prompt_token,
            _make_len(prompt_token),
            prompt_feat.to(self.dtype),
            _make_len(prompt_feat),
            embedding.to(self.dtype),
            self.n_timesteps,
        )
        # inference vocoder
        with torch.no_grad():
            if isinstance(self.hift, BigVGAN):
                # 3 reflected frames of context on each side
                mel = torch.nn.functional.pad(mel, (3, 3), mode='reflect')
                speech = self.hift.inference(mel).squeeze(0)  # [1,1,T] -> [1,T]
            elif isinstance(self.hift, HiFTGenerator):
                speech, _ = self.hift.inference(mel)
            else:
                raise ValueError(f'unsupported vocoder type {type(self.hift)}')
        speech = speech.cpu().to(torch.float32)
        return speech

    @staticmethod
    def _detach_cache(cache: dict) -> dict:
        """Return a copy of ``cache`` with every tensor value cloned and detached.

        Non-tensor values are passed through unchanged.
        """
        return {
            k: (v.clone().detach() if isinstance(v, torch.Tensor) else v)
            for k, v in cache.items()
        }

    @staticmethod
    def _truncate_att_cache(att_cache: torch.Tensor,
                            prompt_length: int,
                            left_context_length: int) -> torch.Tensor:
        """Bound the attention cache along dim 4.

        When the cache is longer than ``prompt_length + left_context_length``
        only the first ``prompt_length`` frames (the prompt) and the most
        recent ``left_context_length`` frames are kept; otherwise the cache is
        returned unchanged.
        """
        total = att_cache.shape[4]
        if total <= prompt_length + left_context_length:
            return att_cache
        return torch.cat([
            att_cache[:, :, :, :, :prompt_length],
            att_cache[:, :, :, :, total - left_context_length:],
        ], dim=4)

    def _setup_cache(self,
                     token: torch.Tensor,
                     mel: torch.Tensor,
                     spk: torch.Tensor,
                     session_id: str,
                     ):
        """NOTE Internal method, do not call this method!

        Handle device & dtype transfer. Creates the flow cache, the speaker
        embedding cache and an empty vocoder cache for ``session_id`` while
        holding ``setup_lock``.

        Args:
            token: prompt tokens followed by the lookahead tokens.
            mel: prompt mel; ``mel.shape[1]`` is recorded as the prompt length
                and ``mel.shape[2]`` sizes the empty mel cache.
            spk: speaker embedding.
            session_id: key under which all caches are stored.
        """
        # att/cnn-cache
        with self.setup_lock:
            cache = self.flow.setup_cache(
                token.to(self.device),
                mel.to(self.device, self.dtype),
                spk.to(self.device, self.dtype),
                self.n_timesteps,
            )
            # clone().detach() every tensor in the cache dict
            self.chunk_cache_dict[session_id] = self._detach_cache(cache)
            self.estimator_prompt_length_dict[session_id] = mel.shape[1]
            self.b_first_chunk_dict[session_id] = True
            # spk embedding
            self.spk_embedding_cache_dict[session_id] = spk.to(self.device, self.dtype).clone()
            # hift cache (all empty along the time axis)
            self.hift_cache_dict[session_id] = dict(
                mel=torch.zeros(1, mel.shape[2], 0, device=self.device, dtype=self.dtype),
                source=torch.zeros(1, 1, 0, device=self.device, dtype=self.dtype),
                speech=torch.zeros(1, 0, device=self.device, dtype=self.dtype),
            )

    def _token2wav_stream(self,
                          token: torch.Tensor,
                          session_id: str,
                          last_chunk: bool,
                          ):
        """NOTE Internal method, do not call this method!

        Handle device transfer. Runs the flow model on one chunk, updates the
        session caches, vocodes the mel (prefixed with the cached mel frames)
        and cross-fades the result with the cached speech tail.

        Args:
            token: regrouped tokens of this chunk (including lookahead).
            session_id: session whose caches are used; must have been set up.
            last_chunk: when False the trailing ``source_cache_len`` samples
                are withheld and kept for cross-fading with the next chunk.

        Returns:
            Speech of this chunk as a float32 CPU tensor.

        Raises:
            AssertionError: if no cache exists for ``session_id``.
            ValueError: if the vocoder type is unsupported.
        """
        assert session_id in self.chunk_cache_dict, 'call setup_cache first to obtain cache'
        # fetch cache & speaker embedding
        cache = self.chunk_cache_dict[session_id]
        embedding = self.spk_embedding_cache_dict[session_id]
        # inference this chunk
        mel, new_cache = self.flow.inference_chunk(
            token.to(self.device),  # int64
            embedding,
            cache,
            last_chunk,
            self.n_timesteps,
        )
        # NOTE(sfy) truncate attention cache (prompt_length + 2s left context)
        # Keep the prompt frames at the front plus the most recent context.
        left_context_length = int(2 * 48)
        new_cache['estimator_att_cache'] = self._truncate_att_cache(
            new_cache['estimator_att_cache'],
            self.estimator_prompt_length_dict[session_id],
            left_context_length,
        )

        self.chunk_cache_dict[session_id] = self._detach_cache(new_cache)
        # vocoder cache
        hift_cache = self.hift_cache_dict[session_id]
        hift_cache_source = hift_cache['source']
        hift_cache_speech = hift_cache['speech']
        mel = torch.concat([hift_cache['mel'], mel], dim=2)
        # inference vocoder
        with torch.no_grad():
            if isinstance(self.hift, BigVGAN):
                if self.b_first_chunk_dict[session_id] and mel.shape[2] > 0:
                    print(f'[INFO] first chunk mel len: {mel.shape[2]}')
                    self.b_first_chunk_dict[session_id] = False
                    # left context only exists from the second chunk on
                    mel = F.pad(mel, (3, 0), mode='reflect')
                if last_chunk:
                    # no future frames will arrive: synthesize right context
                    mel = F.pad(mel, (0, 3), mode='reflect')
                speech = self.hift.inference(mel).squeeze(0)  # [1,1,T] -> [1,T]
                source = torch.zeros(1, 1, 0, device=self.device, dtype=self.dtype)  # dummy source
            elif isinstance(self.hift, HiFTGenerator):
                speech, source = self.hift.inference(mel, hift_cache_source)
            else:
                raise ValueError(f'unsupported vocoder type {type(self.hift)}')
        # overlap speech smooth
        if hift_cache_speech.shape[-1] > 0:
            speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
        # update vocoder cache
        self.hift_cache_dict[session_id] = dict(
            mel=mel[..., -self.mel_cache_len:].clone().detach(),
            source=source[:, :, -self.source_cache_len:].clone().detach(),
            speech=speech[:, -self.source_cache_len:].clone().detach(),
        )
        if not last_chunk:
            # hold back the tail; it is emitted (cross-faded) with the next chunk
            speech = speech[:, :-self.source_cache_len]
        return speech.cpu().to(torch.float32)

    @staticmethod
    def _reshape(mix_seq: List[int]) -> torch.Tensor:
        """Regroup a mixed token sequence into time-aligned pairs.

        Every group of five tokens ``[a, b, c, d, e]`` contributes three rows:
        first column ``a, b, 1024`` and second column ``c+1, d+1, e+1``
        (the second-column tokens are shifted by ``-1024 + 1025``).
        eg.: [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]

        A sequence whose length is not a multiple of five is padded on a copy
        (the caller's list is left untouched); the resulting speech for that
        group is not meaningful. An empty sequence yields a ``(0, 2)`` tensor.

        Args:
            mix_seq: flat list of mixed tokens.

        Returns:
            ``torch.long`` tensor of shape ``(3 * ceil(len / 5), 2)``.
        """
        seq = list(mix_seq)
        # NOTE add padding to avoid assert error
        # (don't care the final speech as it's wrong anyway)
        remainder = len(seq) % 5
        if remainder:
            pad_len = 5 - remainder
            seq.extend([0, 0, 0, 1024, 1024, 1024][-pad_len:])

        vq02: List[int] = []
        vq06: List[int] = []
        for start in range(0, len(seq), 5):
            vq02.extend(seq[start:start + 2])
            vq02.append(1024)
            vq06.extend(seq[start + 2:start + 5])

        return torch.stack([
            torch.tensor(vq02, dtype=torch.long),
            torch.tensor(vq06, dtype=torch.long) - 1024 + 1025,
        ], dim=1)

    def token2wav_stream(self,
                         token: List[int],  # vq0206 mixed seq tokens
                         prompt_token: torch.Tensor,
                         prompt_feat: torch.Tensor,
                         embedding: torch.Tensor,
                         session_id: str,
                         last_chunk: bool,
                         ) -> Optional[torch.Tensor]:
        """NOTE Stream interface. Called whenever one token is generated.
        NOTE(sfy) not need to transfer device or dtype

        This is a specialized version for vq0206, we change the mixed sequence
        to time-aligned sequence.
        eg.: [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]

        New tokens are appended to the session buffer. Until the session cache
        exists the call only (possibly) sets it up and returns ``None``.
        Afterwards a chunk is synthesized once the buffer holds a full chunk
        plus lookahead, or unconditionally when ``last_chunk`` is True.

        Args:
            token: newly generated mixed tokens.
            prompt_token: mixed prompt tokens (used only for cache setup).
            prompt_feat: prompt mel (used only for cache setup).
            embedding: speaker embedding (used only for cache setup).
            session_id: key identifying the stream.
            last_chunk: flush the remaining tokens and release session state.

        Returns:
            Speech for the synthesized chunk, or ``None`` if nothing was
            synthesized in this call.
        """
        # FIXME hard coded
        def _mixed_len(l: int):
            return (l // 3) * 5

        # init chunk size tracking
        if session_id not in self.chunk_size_dict:
            self.chunk_size_dict[session_id] = deepcopy(self.chunk_size_list)
        # add token
        self.speech_token_dict[session_id].extend(token)
        # waiting to setup cache
        mix_token_lookahead_len = _mixed_len(self.token_lookahead)
        if session_id not in self.chunk_cache_dict:
            if len(self.speech_token_dict[session_id]) >= mix_token_lookahead_len:
                # [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]
                lookahead_token = self._reshape(
                    self.speech_token_dict[session_id][:mix_token_lookahead_len]
                ).unsqueeze(0)   # (1, t, 2)
                prompt_token = self._reshape(
                    prompt_token.flatten().tolist()
                ).unsqueeze(0)
                # align prompt mel
                prompt_feat = F.interpolate(
                    prompt_feat.transpose(1, 2),
                    size=prompt_token.shape[1] * 2,
                    mode='nearest'
                ).transpose(1, 2)
                self._setup_cache(
                    torch.cat([prompt_token, lookahead_token], dim=1),
                    prompt_feat,
                    embedding,
                    session_id,
                )
            # NOTE(review): a call with last_chunk=True that lands here returns
            # without synthesizing or cleaning up; the caller must then call
            # clean_up() itself. Behavior kept as-is.
            return None

        # deal with remaining tokens
        this_token = None
        if last_chunk:
            this_token = self.speech_token_dict[session_id]
        else:
            # cut to one chunk
            mix_token_chunk_len = _mixed_len(self.chunk_size_dict[session_id][0])
            needed = mix_token_chunk_len + mix_token_lookahead_len
            if len(self.speech_token_dict[session_id]) >= needed:
                this_token = self.speech_token_dict[session_id][:needed]
                # keep the lookahead tokens: they open the next chunk
                self.speech_token_dict[session_id] = self.speech_token_dict[session_id][mix_token_chunk_len:]
        # go synthesis
        this_speech = None
        if this_token is not None:
            # [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]
            this_token = self._reshape(this_token).unsqueeze(0)
            this_speech = self._token2wav_stream(
                this_token,
                session_id,
                last_chunk,
            )
            # update chunk size
            if len(self.chunk_size_dict[session_id]) > 1:
                self.chunk_size_dict[session_id].pop(0)
        # clear all caches
        if last_chunk:
            self.clean_up(session_id)
        return this_speech

    def clean_up(self, session_id: str):
        """Drop every piece of per-session state for ``session_id``.

        Unknown session ids are ignored. Also asks torch to release cached
        CUDA memory.
        """
        per_session_stores = (
            self.chunk_size_dict,
            self.hift_cache_dict,
            self.chunk_cache_dict,
            self.estimator_prompt_length_dict,
            self.spk_embedding_cache_dict,
            self.speech_token_dict,
            self.b_first_chunk_dict,  # previously leaked one entry per session
        )
        for store in per_session_stores:
            store.pop(session_id, None)
        torch.cuda.empty_cache()


"""Keep compatible with cosyvoice1
"""
class CosyVoice:
    """Loads the flow model, vocoder and frontend from a model directory.

    Keeps the interface compatible with cosyvoice1; the synthesis methods
    simply forward to the wrapped ``CosyVoice_stream_impl_`` instance.
    """

    def __init__(self,
                 model_dir: str,
                 chunk_size_list: List = [15, 24, 48],  # (0.6s, 0.96s, 1.92s)
                 mel_cache_len: int = 8,
                 n_timesteps: int = 10,
                 enable_cuda_graph: bool = True,
                 dtype=torch.float32,
                 ):
        """Build the models from ``model_dir``.

        Reads ``cosyvoice.yaml``, ``flow.pt``, ``hift.pt``, ``campplus.onnx``
        and ``speech_tokenizer_v1.onnx`` from ``model_dir``.

        Args:
            model_dir: directory containing the files listed above.
            chunk_size_list: forwarded to ``CosyVoice_stream_impl_``.
            mel_cache_len: forwarded to ``CosyVoice_stream_impl_``.
            n_timesteps: forwarded to ``CosyVoice_stream_impl_``.
            enable_cuda_graph: initialise CUDA graphs on the flow model and
                vocoder; ignored when no CUDA device is selected.
            dtype: dtype the wrapped module is cast to.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        # initiate streaming wrapper
        self.model_dir = model_dir
        # Explicit encoding so the config parses the same on every platform.
        with open("{}/cosyvoice.yaml".format(model_dir), "r", encoding="utf-8") as f:
            configs = load_hyperpyyaml(f)
        flow, hift = configs['flow'], configs['hift']
        mel_conf = configs['mel_conf']
        # NOTE(security): torch.load unpickles the checkpoint; only point
        # model_dir at trusted files.
        flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location='cpu'))
        flow = flow.eval()
        hift.load_state_dict(torch.load(f"{model_dir}/hift.pt", map_location='cpu'))
        hift = hift.eval()
        cosy_impl = CosyVoice_stream_impl_(flow, hift, chunk_size_list, mel_cache_len, n_timesteps)
        self.cosy_impl = cosy_impl.to(self.device, self.dtype)
        # CUDA graphs need a CUDA device; skip them on the CPU fallback so the
        # default arguments work on machines without a GPU.
        if enable_cuda_graph and self.device.type == "cuda":
            self.cosy_impl.flow.scatter_cuda_graph(enable_cuda_graph)
            self.cosy_impl.hift._init_cuda_graph()
        # feature frontend
        self.frontend = CosyVoiceFrontEnd(
            mel_conf,
            campplus_model='{}/campplus.onnx'.format(model_dir),
            speech_tokenizer_model='{}/speech_tokenizer_v1.onnx'.format(model_dir),
        )

    def token2wav_nonstream(self,
                            token: torch.Tensor,    # vq0206 mixed seq
                            prompt_token: torch.Tensor,
                            prompt_feat: torch.Tensor,
                            embedding: torch.Tensor,
                            ) -> torch.Tensor:
        """Proxy for ``CosyVoice_stream_impl_.token2wav_nonstream``."""
        return self.cosy_impl.token2wav_nonstream(
            token,
            prompt_token,
            prompt_feat,
            embedding,
        )

    def token2wav_stream(self,
                         token: List[int],  # vq0206 mixed seq tokens
                         prompt_token: torch.Tensor,
                         prompt_feat: torch.Tensor,
                         embedding: torch.Tensor,
                         session_id: str,
                         last_chunk: bool,
                         ) -> Optional[torch.Tensor]:
        """Proxy for ``CosyVoice_stream_impl_.token2wav_stream``.

        Returns ``None`` when the call did not produce audio.
        """
        return self.cosy_impl.token2wav_stream(
            token,
            prompt_token,
            prompt_feat,
            embedding,
            session_id,
            last_chunk,
        )

    def clean_up(self, session_id: str):
        """Release all cached state of ``session_id`` in the wrapped module."""
        self.cosy_impl.clean_up(session_id)
